{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-05T03:35:02.869619Z",
     "start_time": "2019-11-05T03:35:00.471934Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div class=\"bk-root\">\n",
       "        <a href=\"https://bokeh.pydata.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n",
       "        <span id=\"1001\">Loading BokehJS ...</span>\n",
       "    </div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/javascript": [
       "\n",
       "(function(root) {\n",
       "  function now() {\n",
       "    return new Date();\n",
       "  }\n",
       "\n",
       "  var force = true;\n",
       "\n",
       "  if (typeof (root._bokeh_onload_callbacks) === \"undefined\" || force === true) {\n",
       "    root._bokeh_onload_callbacks = [];\n",
       "    root._bokeh_is_loading = undefined;\n",
       "  }\n",
       "\n",
       "  var JS_MIME_TYPE = 'application/javascript';\n",
       "  var HTML_MIME_TYPE = 'text/html';\n",
       "  var EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n",
       "  var CLASS_NAME = 'output_bokeh rendered_html';\n",
       "\n",
       "  /**\n",
       "   * Render data to the DOM node\n",
       "   */\n",
       "  function render(props, node) {\n",
       "    var script = document.createElement(\"script\");\n",
       "    node.appendChild(script);\n",
       "  }\n",
       "\n",
       "  /**\n",
       "   * Handle when an output is cleared or removed\n",
       "   */\n",
       "  function handleClearOutput(event, handle) {\n",
       "    var cell = handle.cell;\n",
       "\n",
       "    var id = cell.output_area._bokeh_element_id;\n",
       "    var server_id = cell.output_area._bokeh_server_id;\n",
       "    // Clean up Bokeh references\n",
       "    if (id != null && id in Bokeh.index) {\n",
       "      Bokeh.index[id].model.document.clear();\n",
       "      delete Bokeh.index[id];\n",
       "    }\n",
       "\n",
       "    if (server_id !== undefined) {\n",
       "      // Clean up Bokeh references\n",
       "      var cmd = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n",
       "      cell.notebook.kernel.execute(cmd, {\n",
       "        iopub: {\n",
       "          output: function(msg) {\n",
       "            var id = msg.content.text.trim();\n",
       "            if (id in Bokeh.index) {\n",
       "              Bokeh.index[id].model.document.clear();\n",
       "              delete Bokeh.index[id];\n",
       "            }\n",
       "          }\n",
       "        }\n",
       "      });\n",
       "      // Destroy server and session\n",
       "      var cmd = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n",
       "      cell.notebook.kernel.execute(cmd);\n",
       "    }\n",
       "  }\n",
       "\n",
       "  /**\n",
       "   * Handle when a new output is added\n",
       "   */\n",
       "  function handleAddOutput(event, handle) {\n",
       "    var output_area = handle.output_area;\n",
       "    var output = handle.output;\n",
       "\n",
       "    // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n",
       "    if ((output.output_type != \"display_data\") || (!output.data.hasOwnProperty(EXEC_MIME_TYPE))) {\n",
       "      return\n",
       "    }\n",
       "\n",
       "    var toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n",
       "\n",
       "    if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n",
       "      toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n",
       "      // store reference to embed id on output_area\n",
       "      output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n",
       "    }\n",
       "    if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n",
       "      var bk_div = document.createElement(\"div\");\n",
       "      bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n",
       "      var script_attrs = bk_div.children[0].attributes;\n",
       "      for (var i = 0; i < script_attrs.length; i++) {\n",
       "        toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n",
       "      }\n",
       "      // store reference to server id on output_area\n",
       "      output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n",
       "    }\n",
       "  }\n",
       "\n",
       "  function register_renderer(events, OutputArea) {\n",
       "\n",
       "    function append_mime(data, metadata, element) {\n",
       "      // create a DOM node to render to\n",
       "      var toinsert = this.create_output_subarea(\n",
       "        metadata,\n",
       "        CLASS_NAME,\n",
       "        EXEC_MIME_TYPE\n",
       "      );\n",
       "      this.keyboard_manager.register_events(toinsert);\n",
       "      // Render to node\n",
       "      var props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n",
       "      render(props, toinsert[toinsert.length - 1]);\n",
       "      element.append(toinsert);\n",
       "      return toinsert\n",
       "    }\n",
       "\n",
       "    /* Handle when an output is cleared or removed */\n",
       "    events.on('clear_output.CodeCell', handleClearOutput);\n",
       "    events.on('delete.Cell', handleClearOutput);\n",
       "\n",
       "    /* Handle when a new output is added */\n",
       "    events.on('output_added.OutputArea', handleAddOutput);\n",
       "\n",
       "    /**\n",
       "     * Register the mime type and append_mime function with output_area\n",
       "     */\n",
       "    OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n",
       "      /* Is output safe? */\n",
       "      safe: true,\n",
       "      /* Index of renderer in `output_area.display_order` */\n",
       "      index: 0\n",
       "    });\n",
       "  }\n",
       "\n",
       "  // register the mime type if in Jupyter Notebook environment and previously unregistered\n",
       "  if (root.Jupyter !== undefined) {\n",
       "    var events = require('base/js/events');\n",
       "    var OutputArea = require('notebook/js/outputarea').OutputArea;\n",
       "\n",
       "    if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n",
       "      register_renderer(events, OutputArea);\n",
       "    }\n",
       "  }\n",
       "\n",
       "  \n",
       "  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n",
       "    root._bokeh_timeout = Date.now() + 5000;\n",
       "    root._bokeh_failed_load = false;\n",
       "  }\n",
       "\n",
       "  var NB_LOAD_WARNING = {'data': {'text/html':\n",
       "     \"<div style='background-color: #fdd'>\\n\"+\n",
       "     \"<p>\\n\"+\n",
       "     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n",
       "     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n",
       "     \"</p>\\n\"+\n",
       "     \"<ul>\\n\"+\n",
       "     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n",
       "     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n",
       "     \"</ul>\\n\"+\n",
       "     \"<code>\\n\"+\n",
       "     \"from bokeh.resources import INLINE\\n\"+\n",
       "     \"output_notebook(resources=INLINE)\\n\"+\n",
       "     \"</code>\\n\"+\n",
       "     \"</div>\"}};\n",
       "\n",
       "  function display_loaded() {\n",
       "    var el = document.getElementById(\"1001\");\n",
       "    if (el != null) {\n",
       "      el.textContent = \"BokehJS is loading...\";\n",
       "    }\n",
       "    if (root.Bokeh !== undefined) {\n",
       "      if (el != null) {\n",
       "        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n",
       "      }\n",
       "    } else if (Date.now() < root._bokeh_timeout) {\n",
       "      setTimeout(display_loaded, 100)\n",
       "    }\n",
       "  }\n",
       "\n",
       "\n",
       "  function run_callbacks() {\n",
       "    try {\n",
       "      root._bokeh_onload_callbacks.forEach(function(callback) { callback() });\n",
       "    }\n",
       "    finally {\n",
       "      delete root._bokeh_onload_callbacks\n",
       "    }\n",
       "    console.info(\"Bokeh: all callbacks have finished\");\n",
       "  }\n",
       "\n",
       "  function load_libs(js_urls, callback) {\n",
       "    root._bokeh_onload_callbacks.push(callback);\n",
       "    if (root._bokeh_is_loading > 0) {\n",
       "      console.log(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n",
       "      return null;\n",
       "    }\n",
       "    if (js_urls == null || js_urls.length === 0) {\n",
       "      run_callbacks();\n",
       "      return null;\n",
       "    }\n",
       "    console.log(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n",
       "    root._bokeh_is_loading = js_urls.length;\n",
       "    for (var i = 0; i < js_urls.length; i++) {\n",
       "      var url = js_urls[i];\n",
       "      var s = document.createElement('script');\n",
       "      s.src = url;\n",
       "      s.async = false;\n",
       "      s.onreadystatechange = s.onload = function() {\n",
       "        root._bokeh_is_loading--;\n",
       "        if (root._bokeh_is_loading === 0) {\n",
       "          console.log(\"Bokeh: all BokehJS libraries loaded\");\n",
       "          run_callbacks()\n",
       "        }\n",
       "      };\n",
       "      s.onerror = function() {\n",
       "        console.warn(\"failed to load library \" + url);\n",
       "      };\n",
       "      console.log(\"Bokeh: injecting script tag for BokehJS library: \", url);\n",
       "      document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
       "    }\n",
       "  };var element = document.getElementById(\"1001\");\n",
       "  if (element == null) {\n",
       "    console.log(\"Bokeh: ERROR: autoload.js configured with elementid '1001' but no matching script tag was found. \")\n",
       "    return false;\n",
       "  }\n",
       "\n",
       "  var js_urls = [\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-gl-1.0.0.min.js\"];\n",
       "\n",
       "  var inline_js = [\n",
       "    function(Bokeh) {\n",
       "      Bokeh.set_log_level(\"info\");\n",
       "    },\n",
       "    \n",
       "    function(Bokeh) {\n",
       "      \n",
       "    },\n",
       "    function(Bokeh) {\n",
       "      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n",
       "      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n",
       "      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n",
       "      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n",
       "      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n",
       "      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n",
       "    }\n",
       "  ];\n",
       "\n",
       "  function run_inline_js() {\n",
       "    \n",
       "    if ((root.Bokeh !== undefined) || (force === true)) {\n",
       "      for (var i = 0; i < inline_js.length; i++) {\n",
       "        inline_js[i].call(root, root.Bokeh);\n",
       "      }if (force === true) {\n",
       "        display_loaded();\n",
       "      }} else if (Date.now() < root._bokeh_timeout) {\n",
       "      setTimeout(run_inline_js, 100);\n",
       "    } else if (!root._bokeh_failed_load) {\n",
       "      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n",
       "      root._bokeh_failed_load = true;\n",
       "    } else if (force !== true) {\n",
       "      var cell = $(document.getElementById(\"1001\")).parents('.cell').data().cell;\n",
       "      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n",
       "    }\n",
       "\n",
       "  }\n",
       "\n",
       "  if (root._bokeh_is_loading === 0) {\n",
       "    console.log(\"Bokeh: BokehJS loaded, going straight to plotting\");\n",
       "    run_inline_js();\n",
       "  } else {\n",
       "    load_libs(js_urls, function() {\n",
       "      console.log(\"Bokeh: BokehJS plotting callback run at\", now());\n",
       "      run_inline_js();\n",
       "    });\n",
       "  }\n",
       "}(window));"
      ],
      "application/vnd.bokehjs_load.v0+json": "\n(function(root) {\n  function now() {\n    return new Date();\n  }\n\n  var force = true;\n\n  if (typeof (root._bokeh_onload_callbacks) === \"undefined\" || force === true) {\n    root._bokeh_onload_callbacks = [];\n    root._bokeh_is_loading = undefined;\n  }\n\n  \n\n  \n  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n    root._bokeh_timeout = Date.now() + 5000;\n    root._bokeh_failed_load = false;\n  }\n\n  var NB_LOAD_WARNING = {'data': {'text/html':\n     \"<div style='background-color: #fdd'>\\n\"+\n     \"<p>\\n\"+\n     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n     \"</p>\\n\"+\n     \"<ul>\\n\"+\n     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n     \"</ul>\\n\"+\n     \"<code>\\n\"+\n     \"from bokeh.resources import INLINE\\n\"+\n     \"output_notebook(resources=INLINE)\\n\"+\n     \"</code>\\n\"+\n     \"</div>\"}};\n\n  function display_loaded() {\n    var el = document.getElementById(\"1001\");\n    if (el != null) {\n      el.textContent = \"BokehJS is loading...\";\n    }\n    if (root.Bokeh !== undefined) {\n      if (el != null) {\n        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n      }\n    } else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(display_loaded, 100)\n    }\n  }\n\n\n  function run_callbacks() {\n    try {\n      root._bokeh_onload_callbacks.forEach(function(callback) { callback() });\n    }\n    finally {\n      delete root._bokeh_onload_callbacks\n    }\n    console.info(\"Bokeh: all callbacks have finished\");\n  }\n\n  function load_libs(js_urls, callback) {\n    root._bokeh_onload_callbacks.push(callback);\n    if (root._bokeh_is_loading > 0) {\n      
console.log(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n      return null;\n    }\n    if (js_urls == null || js_urls.length === 0) {\n      run_callbacks();\n      return null;\n    }\n    console.log(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n    root._bokeh_is_loading = js_urls.length;\n    for (var i = 0; i < js_urls.length; i++) {\n      var url = js_urls[i];\n      var s = document.createElement('script');\n      s.src = url;\n      s.async = false;\n      s.onreadystatechange = s.onload = function() {\n        root._bokeh_is_loading--;\n        if (root._bokeh_is_loading === 0) {\n          console.log(\"Bokeh: all BokehJS libraries loaded\");\n          run_callbacks()\n        }\n      };\n      s.onerror = function() {\n        console.warn(\"failed to load library \" + url);\n      };\n      console.log(\"Bokeh: injecting script tag for BokehJS library: \", url);\n      document.getElementsByTagName(\"head\")[0].appendChild(s);\n    }\n  };var element = document.getElementById(\"1001\");\n  if (element == null) {\n    console.log(\"Bokeh: ERROR: autoload.js configured with elementid '1001' but no matching script tag was found. 
\")\n    return false;\n  }\n\n  var js_urls = [\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-gl-1.0.0.min.js\"];\n\n  var inline_js = [\n    function(Bokeh) {\n      Bokeh.set_log_level(\"info\");\n    },\n    \n    function(Bokeh) {\n      \n    },\n    function(Bokeh) {\n      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n    }\n  ];\n\n  function run_inline_js() {\n    \n    if ((root.Bokeh !== undefined) || (force === true)) {\n      for (var i = 0; i < inline_js.length; i++) {\n        inline_js[i].call(root, root.Bokeh);\n      }if (force === true) {\n        display_loaded();\n      }} else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(run_inline_js, 100);\n    } else if (!root._bokeh_failed_load) {\n      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n      root._bokeh_failed_load = true;\n    } else if (force !== true) {\n      var cell = $(document.getElementById(\"1001\")).parents('.cell').data().cell;\n      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n    }\n\n  }\n\n  if (root._bokeh_is_loading === 0) {\n    console.log(\"Bokeh: BokehJS loaded, going straight to plotting\");\n    run_inline_js();\n  } else {\n    load_libs(js_urls, function() {\n   
   console.log(\"Bokeh: BokehJS plotting callback run at\", now());\n      run_inline_js();\n    });\n  }\n}(window));"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.datasets as dsets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from bokeh.io import show, output_notebook\n",
    "from bokeh.plotting import figure, gridplot\n",
    "from bokeh.models import LinearAxis, Range1d\n",
    "output_notebook()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPU\n",
    "指定使用的 GPU 编号。  \n",
    "可以用 `watch -n 1 nvidia-smi` 实时查看各块 GPU 的运行状态。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-05T03:35:03.151390Z",
     "start_time": "2019-11-05T03:35:02.871941Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index of the GPU this notebook runs on; change it here to use another card.\n",
    "GPU_ID = 1\n",
    "# Make this GPU the current device, i.e. the default target of later argument-less `.cuda()` calls.\n",
    "torch.cuda.set_device(GPU_ID)\n",
    "torch.cuda.current_device()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-02T15:34:18.036850Z",
     "start_time": "2019-11-02T15:34:18.032937Z"
    }
   },
   "source": [
    "### Data\n",
    "通过`torchvision.datasets`下载`MNIST`数据。  \n",
    "训练集：`train=True`  \n",
    "测试集：`train=False`  \n",
    "常用的还有`torchvision.datasets.ImageFolder()`，按文件夹取图片。  \n",
    "\n",
    "`torchvision.transforms`可以对图片做处理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-05T03:35:03.332972Z",
     "start_time": "2019-11-05T03:35:03.153726Z"
    }
   },
   "outputs": [],
   "source": [
    "# Both splits are stored under the same root and use the same image -> tensor transform.\n",
    "DATA_ROOT = '../dataset'\n",
    "to_tensor = transforms.ToTensor()\n",
    "\n",
    "train_dataset = dsets.MNIST(root=DATA_ROOT, train=True, transform=to_tensor, download=True)\n",
    "test_dataset = dsets.MNIST(root=DATA_ROOT, train=False, transform=to_tensor, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-05T03:35:03.389756Z",
     "start_time": "2019-11-05T03:35:03.335530Z"
    }
   },
   "outputs": [],
   "source": [
    "batch_size = 50\n",
    "train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n",
    "# Evaluation gains nothing from shuffling; a fixed order makes the misclassified\n",
    "# examples collected during testing the same on every run.\n",
    "test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model\n",
    "构造 CNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-05T03:35:03.487725Z",
     "start_time": "2019-11-05T03:35:03.391610Z"
    }
   },
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    \"\"\"Small convolutional classifier for single-channel 28x28 digit images.\n",
    "\n",
    "    Two blocks of Conv2d(5x5, stride 1, padding 2) -> ReLU -> MaxPool2d(2) shrink the\n",
    "    spatial size 28 -> 14 -> 7, so the flattened features have 32*7*7 entries.\n",
    "    forward() returns one unnormalized score per class (10 classes); no softmax is applied.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # 1 -> 16 channels; padding=2 keeps the 5x5 conv size-preserving, the pool halves H and W.\n",
    "        self.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "        )\n",
    "        # 16 -> 32 channels, same conv/pool pattern.\n",
    "        self.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "        )\n",
    "        self.linear = nn.Linear(32 * 7 * 7, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.conv2(self.conv1(x))\n",
    "        # Flatten all but the batch dimension; for 28x28 inputs: (N, 32, 7, 7) -> (N, 1568).\n",
    "        flat = features.view(features.size(0), -1)\n",
    "        return self.linear(flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-05T03:35:12.792864Z",
     "start_time": "2019-11-05T03:35:03.489810Z"
    }
   },
   "outputs": [],
   "source": [
    "lrate = 0.001  # Adam learning rate\n",
    "epochs = 3     # full passes over the training set\n",
    "SEED = 0\n",
    "\n",
    "# Seed torch's global RNG (CPU and all GPUs) so weight initialisation and the training\n",
    "# loader's shuffle order repeat under Restart & Run All (cuDNN kernels may still be nondeterministic).\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# The model runs in float64 (.double()), so every input batch must also be cast with .double().\n",
    "model = CNN().double().cuda()\n",
    "criterion = nn.CrossEntropyLoss()  # takes raw class scores and integer class targets\n",
    "optim = torch.optim.Adam(model.parameters(), lr=lrate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-05T03:35:12.804924Z",
     "start_time": "2019-11-05T03:35:12.795979Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv1): Sequential(\n",
       "    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
       "    (1): ReLU()\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (conv2): Sequential(\n",
       "    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
       "    (1): ReLU()\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (linear): Linear(in_features=1568, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意：梯度会在多次 `backward()` 调用之间累加，所以每次反向传播**之前**都需要用 `optim.zero_grad()` 将参数的梯度归零。  \n",
    "`optim.step()` 则在所有参数的 `grad` 都由 `loss.backward()` 计算出来之后，根据梯度更新每个参数的数值。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "每次迭代都从 `train_loader` 中取出一个 `batch` 作为训练数据。  \n",
    "每个 `batch` 在使用之前，先通过 `Tensor.cuda()` 移入 `GPU`，再进行计算。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-05T03:36:25.714015Z",
     "start_time": "2019-11-05T03:35:12.816041Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "  <div class=\"bk-root\" id=\"06864b49-76b3-4bdf-a06b-a109aede11eb\"></div>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/javascript": [
       "(function(root) {\n",
       "  function embed_document(root) {\n",
       "    \n",
       "  var docs_json = {\"8cbcb812-7292-49e0-8b20-3488d2279f76\":{\"roots\":{\"references\":[{\"attributes\":{\"below\":[{\"id\":\"1011\",\"type\":\"LinearAxis\"}],\"left\":[{\"id\":\"1016\",\"type\":\"LinearAxis\"}],\"renderers\":[{\"id\":\"1011\",\"type\":\"LinearAxis\"},{\"id\":\"1015\",\"type\":\"Grid\"},{\"id\":\"1016\",\"type\":\"LinearAxis\"},{\"id\":\"1020\",\"type\":\"Grid\"},{\"id\":\"1029\",\"type\":\"BoxAnnotation\"},{\"id\":\"1039\",\"type\":\"GlyphRenderer\"}],\"title\":{\"id\":\"1041\",\"type\":\"Title\"},\"toolbar\":{\"id\":\"1027\",\"type\":\"Toolbar\"},\"x_range\":{\"id\":\"1003\",\"type\":\"DataRange1d\"},\"x_scale\":{\"id\":\"1007\",\"type\":\"LinearScale\"},\"y_range\":{\"id\":\"1005\",\"type\":\"DataRange1d\"},\"y_scale\":{\"id\":\"1009\",\"type\":\"LinearScale\"}},\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},{\"attributes\":{\"plot\":null,\"text\":\"\"},\"id\":\"1041\",\"type\":\"Title\"},{\"attributes\":{\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1012\",\"type\":\"BasicTicker\"}},\"id\":\"1015\",\"type\":\"Grid\"},{\"attributes\":{},\"id\":\"1043\",\"type\":\"BasicTickFormatter\"},{\"attributes\":{\"formatter\":{\"id\":\"1045\",\"type\":\"BasicTickFormatter\"},\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1017\",\"type\":\"BasicTicker\"}},\"id\":\"1016\",\"type\":\"LinearAxis\"},{\"attributes\":{},\"id\":\"1045\",\"type\":\"BasicTickFormatter\"},{\"attributes\":{},\"id\":\"1017\",\"type\":\"BasicTicker\"},{\"attributes\":{},\"id\":\"1048\",\"type\":\"Selection\"},{\"attributes\":{\"dimension\":1,\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1017\",\"type\":\"BasicTicker\"}},\"id\":\"1020\",\"type\":\"Grid\"},{\"attributes\":{},\"id\":\"1049\",\"type\":\"UnionRenderers\"},{\"attributes\":{\"callback\":null,\"data\":{\"x\":[0,1,2],\"y\":[0.03337550947159823,0.07055367252357084,0.09845329805324582]},
\"selected\":{\"id\":\"1048\",\"type\":\"Selection\"},\"selection_policy\":{\"id\":\"1049\",\"type\":\"UnionRenderers\"}},\"id\":\"1036\",\"type\":\"ColumnDataSource\"},{\"attributes\":{\"line_color\":\"#1f77b4\",\"x\":{\"field\":\"x\"},\"y\":{\"field\":\"y\"}},\"id\":\"1037\",\"type\":\"Line\"},{\"attributes\":{},\"id\":\"1021\",\"type\":\"PanTool\"},{\"attributes\":{},\"id\":\"1022\",\"type\":\"WheelZoomTool\"},{\"attributes\":{\"callback\":null},\"id\":\"1003\",\"type\":\"DataRange1d\"},{\"attributes\":{\"overlay\":{\"id\":\"1029\",\"type\":\"BoxAnnotation\"}},\"id\":\"1023\",\"type\":\"BoxZoomTool\"},{\"attributes\":{},\"id\":\"1024\",\"type\":\"SaveTool\"},{\"attributes\":{},\"id\":\"1025\",\"type\":\"ResetTool\"},{\"attributes\":{},\"id\":\"1026\",\"type\":\"HelpTool\"},{\"attributes\":{\"callback\":null},\"id\":\"1005\",\"type\":\"DataRange1d\"},{\"attributes\":{\"active_drag\":\"auto\",\"active_inspect\":\"auto\",\"active_multi\":null,\"active_scroll\":\"auto\",\"active_tap\":\"auto\",\"tools\":[{\"id\":\"1021\",\"type\":\"PanTool\"},{\"id\":\"1022\",\"type\":\"WheelZoomTool\"},{\"id\":\"1023\",\"type\":\"BoxZoomTool\"},{\"id\":\"1024\",\"type\":\"SaveTool\"},{\"id\":\"1025\",\"type\":\"ResetTool\"},{\"id\":\"1026\",\"type\":\"HelpTool\"}]},\"id\":\"1027\",\"type\":\"Toolbar\"},{\"attributes\":{\"line_alpha\":0.1,\"line_color\":\"#1f77b4\",\"x\":{\"field\":\"x\"},\"y\":{\"field\":\"y\"}},\"id\":\"1038\",\"type\":\"Line\"},{\"attributes\":{},\"id\":\"1007\",\"type\":\"LinearScale\"},{\"attributes\":{\"bottom_units\":\"screen\",\"fill_alpha\":{\"value\":0.5},\"fill_color\":{\"value\":\"lightgrey\"},\"left_units\":\"screen\",\"level\":\"overlay\",\"line_alpha\":{\"value\":1.0},\"line_color\":{\"value\":\"black\"},\"line_dash\":[4,4],\"line_width\":{\"value\":2},\"plot\":null,\"render_mode\":\"css\",\"right_units\":\"screen\",\"top_units\":\"screen\"},\"id\":\"1029\",\"type\":\"BoxAnnotation\"},{\"attributes\":{},\"id\":\"1009\",\"type\":\"LinearScale\"},{\"attr
ibutes\":{\"data_source\":{\"id\":\"1036\",\"type\":\"ColumnDataSource\"},\"glyph\":{\"id\":\"1037\",\"type\":\"Line\"},\"hover_glyph\":null,\"muted_glyph\":null,\"nonselection_glyph\":{\"id\":\"1038\",\"type\":\"Line\"},\"selection_glyph\":null,\"view\":{\"id\":\"1040\",\"type\":\"CDSView\"}},\"id\":\"1039\",\"type\":\"GlyphRenderer\"},{\"attributes\":{\"formatter\":{\"id\":\"1043\",\"type\":\"BasicTickFormatter\"},\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1012\",\"type\":\"BasicTicker\"}},\"id\":\"1011\",\"type\":\"LinearAxis\"},{\"attributes\":{\"source\":{\"id\":\"1036\",\"type\":\"ColumnDataSource\"}},\"id\":\"1040\",\"type\":\"CDSView\"},{\"attributes\":{},\"id\":\"1012\",\"type\":\"BasicTicker\"}],\"root_ids\":[\"1002\"]},\"title\":\"Bokeh Application\",\"version\":\"1.0.0\"}};\n",
       "  var render_items = [{\"docid\":\"8cbcb812-7292-49e0-8b20-3488d2279f76\",\"roots\":{\"1002\":\"06864b49-76b3-4bdf-a06b-a109aede11eb\"}}];\n",
       "  root.Bokeh.embed.embed_items_notebook(docs_json, render_items);\n",
       "\n",
       "  }\n",
       "  if (root.Bokeh !== undefined) {\n",
       "    embed_document(root);\n",
       "  } else {\n",
       "    var attempts = 0;\n",
       "    var timer = setInterval(function(root) {\n",
       "      if (root.Bokeh !== undefined) {\n",
       "        embed_document(root);\n",
       "        clearInterval(timer);\n",
       "      }\n",
       "      attempts++;\n",
       "      if (attempts > 100) {\n",
       "        console.log(\"Bokeh: ERROR: Unable to run BokehJS code because BokehJS library is missing\");\n",
       "        clearInterval(timer);\n",
       "      }\n",
       "    }, 10, root)\n",
       "  }\n",
       "})(window);"
      ],
      "application/vnd.bokehjs_exec.v0+json": ""
     },
     "metadata": {
      "application/vnd.bokehjs_exec.v0+json": {
       "id": "1002"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "result = []  # mean training loss of each epoch\n",
    "for epoch in range(epochs):\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    for inputs, targets in train_loader:\n",
    "        # The model is float64 on the GPU, so cast and move each batch to match.\n",
    "        inputs = inputs.double().cuda()\n",
    "        targets = targets.cuda()\n",
    "        optim.zero_grad()  # gradients accumulate across backward() calls\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, targets)\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        # criterion returns the batch mean; weight by batch size to get an exact epoch mean.\n",
    "        running_loss += loss.item() * inputs.size(0)\n",
    "    # Record the epoch average rather than only the last (noisy) mini-batch loss.\n",
    "    result.append(running_loss / len(train_loader.dataset))\n",
    "\n",
    "fig = figure(title='Training loss', x_axis_label='epoch', y_axis_label='mean cross-entropy loss')\n",
    "fig.line(range(len(result)), result)\n",
    "show(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`re = torch.max(Tensor, dim)` 返回一个二元组 `(values, indices)`：`re[0]` 为沿 `dim` 方向的最大值（`Tensor`），`re[1]` 为最大值对应的索引（`index`），在分类任务中即预测的类别。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-05T03:36:39.925803Z",
     "start_time": "2019-11-05T03:36:25.722691Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the model on the 10000 test images: 98.80 %\n"
     ]
    }
   ],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "wrong_count = 0\n",
    "wrong_classify = []\n",
    "for i, (inputs, targets) in enumerate(test_loader):\n",
    "    targets = targets.cuda()\n",
    "    inputs = inputs.double().cuda()\n",
    "    outputs = model(inputs)\n",
    "    _, preds = torch.max(outputs.data, 1)\n",
    "    total += len(outputs)\n",
    "    correct += (preds == targets).sum()\n",
    "    if wrong_count < 5 and preds != targets:\n",
    "        wrong_classify.append([inputs.data, preds, targets])\n",
    "        wrong_count += 1\n",
    "accuracy = 100 * correct.double() / total\n",
    "print('Accuracy of the model on the 10000 test images: %.2f %%' % (accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-05T03:36:41.014500Z",
     "start_time": "2019-11-05T03:36:39.927549Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2oAAAC/CAYAAACPMC8KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XmYVdWV/vF3VTEpyuwAiCAqDqhBQU2c2rRJO0RF7Rg1iVNUtB1jowaNv2h3YgeTqK1xCoqCxrkdQDFGQzAaZ1BEBBFUBBQVR1ARqKr9++NekpK1izpVdzrn1vfzPPeBeusM+1StOnX2vfesshCCAAAAAADpUVPpAQAAAAAAvo6JGgAAAACkDBM1AAAAAEgZJmoAAAAAkDJM1AAAAAAgZZioAQAAAEDKMFFrw8zscTM7sdLjAJIys3Fm9qtKjwNIippF1nBtgKyp5vNsm5qomdl8M/tOmfc5xMymmdmX+X+HlHP/hTCzDczsdjP71Mw+MbPbKj2mtqbcNWtmg8xsgpktMbOPzezPZrZVufZfCDP7kZl93ujxpZkFMxta6bG1JRU6zx5kZjPz3/enzWzbcu6/tczsm2b2WP5nbYmZ3WNmvSs9rraGa4OW4dqg8rg2SC7r1wZtaqLWHDNrV+TtdZA0QdIfJXWXNF7ShHxeVMUee959kt6T1F/ShpJ+V4J9oAAl+L53kzRR0laSNpL0vHI1XHTFHnsI4bYQwnqrH5JOlfSmpBeLuR8UpgTn2S0l3SbpFOXq90FJE0txTizBNrtLGiNpgHLn2WWSbi7yPlAgrg0crg1SjmuDf8r8tUEIoU08JN0qqUHSckmfSzpPuV+OQdIJkhZIekLS3pIWrbHufEnfyf+/RtIoSW9I+kjS3ZJ6NLHPf5P0jiRrlC2QtF/CMQdJZypXUB9K+q2kmvznjpP0lKQrJH0s6Vf5/CeSZkv6RNKfJfVvtL3vSnpN0meSrpb0N0knrmXs8yXVVvp711YflajZyBh65PfXM+Hy8yWdL2lWvgZvltQp/7m9JS2S9DPlfsnfms8PlDRd0qeSnpa0Q6Pt7ajcyXSZpLsk3bm61hOMZYqkiyr9fWxLjwqdZ0+XNKnRxzX5/e+TwZrdSdKySn8f29KjQjXLtQGPTNVsZAxcG5Tp0WZeUQshHK1c8R4UcrPq3zT69L9I2kbSvgk2daakQ/Lr9FGu4K5Z/Ukzm2FmP8x/OFjSjJCvjLwZ+TypQyUNU+4X+HDlTrar7arciXpDSZeY2SGSLpB0mKQNJD0p6Y78uHpJulfShZJ6KfeDuXujcW+afxvDpvnom5LmSBpvZh+Z2Qtm9i8tGDcKVKGaXdNekt4LIXzUgqH/KD+uzSUNUq7mVttYuRN8f0kjzGwnSTdJOllST0l/UO7VkI75Z5cfUO6XUg9J90j698Y7ytfsHmsOwMz658d+SwvGjQJVqGYt/9AaH2/XgqFXvGbz9pL0agvGjQJxbcC1QdZwbdDGrg0qPVMs50ONnknIfzxAuWcEBjbK9tban4GYrUbP1ErqLWmVpHaR/f0/SXeukd0m6eKE4w1q9Aybci/XTs7//zhJC9ZY/k+STmj0cY2kL5Ur/GMkPdvoc6bcMxhNPWs2Rv98dqa9pCOVe1ajV6W/j23pUe6aXWMbmyj3rO9RLRzvKY0+PkDSG43GuVL5Z9Hy2XWSfrnGNuYo94tjL0nv6uvPOj+tBM+a5X/2Hq/0968tPipwnt1a0hf5bXbIf+8bJJ3fgvGmoWZ3UO4VkD0r/T1sa48K1CzXBjwyVbNrbINrgzI+2swras1Y2IJl+0u6Pz9b/1S5Qq9X7j27a/pcUpc1si7KvVTbmrG9rdyzHrHPrR7blY3G9rFyJ92++fX+sXzIVezajnu5pPkhhLEhhFUhhDvzy+++lnVQPqWqWUm5m8UlPSrp2hDCHQWMbc2aXRJC+GqNsY1cPbb8+Prl
1+kj6Z18rTbeXhLHKHffB9KjJDUbQnhN0rHKvWVrsXKvCsxS7mKzNWMre82a2RbKXUyfFUJ4sgXjRmlxbeBxbZBuXBusXeauDdraRC0kyL+QtO7qD8ysVrm3Cqy2UNL+IYRujR6dQgjvRLb7qqQdzKzx23J2UMve2tKv0f83Ve5ZhNi4V4/t5DXGtk4I4WnlLmD+sa38mPqpaTMi20f5lbtmZWbdlTsRTwwhXNKKMbe0Zi9ZY2zr5n8BLJbUd42fn03VDDPbXbkT+f+1YuwoXNlrNoTwfyGE7UIIPSVdpNwv+RdaMOaK1Wz+rTh/Ue7Z41tbMGYUD9cG/zwurg2ygWuDNnJt0NYmau9LGtjMMq9L6mRm3zOz9sq9h7Zjo89fr9x7vvtL/2hTO7yJbT2u3LMTZ+bfV3t6Pv9rft3jzGx+M+M518y6m1k/SWcpd9NkU66XdL6ZDc5vv6uZHZ7/3CRJg83ssHxHnTOVe09wU+6X1N3MjjWzWjP7vnLPvj3VzHhRXGWtWTProtyN5k+FEEZFPr+3mTX3S/o0M9vEzHood1/E2mr2BkmnmNmultM5fxzrS3pGUp1yPz/tzOwwSbs0s28p9+rKvSGEljw7jeIp93lWZjY0f57aQLl7GR7Mv9KW6po1s77K/T64JoRwfTNjROlwbcC1QdZwbdBGrg3a2kTt15IuzL+Mek5sgRDCZ8q93/tG5d6D+4W+/haaK5VrUfqomS2T9KxyN+5KkszsVTP7UX5bK5W7UfMY5d7D/RNJh+RzKffsQnMntwmSpinX+WaSpLFNLRhCuF/SpZLuNLOlkmZK2j//uQ8lHS5ptHLdfbZsvG/L3TD8ueVvGA4hfCzpYEnnKNcJapSk4fntoHzKWrPK3aC+s6Tj7et/d2T1s1X9lDtJrs3tyj3r9mb+0eQfoQwhTJV0knJvW/tE0jzl7rFY/fNzWP7jTyQdoVxb6H/Ij23PRh93kvQDZeytDVWm3DW7evlPlbuH4VPlamq1NNfsicpdbF3U+OetmbGi+Lg24Noga7g2aCPXBvb1t3iinMzsUeXuSZjdxOeDpC1DCPPKOzIgzsxulHRPCOHPTXx+vnI3of+lrAMDmkDNImu4NkDWcJ4tnVL8IUQkFEL4t0qPAWiJEMKJlR4D0BLULLKGawNkDefZ0mlrb30EAAAAgNTjrY8AAAAAkDK8ogYAAAAAKVPQPWpmtp9yXWNqJd0YQhi9tuU7WMfQSZ0L2SXasK/0hVaGFdb8kk2jZlFOxahZqWV1S82iUMv0yYchhA2aX7Jp1CzKqdw1K1G3KEzS64NWT9Qs94fzrpH0XeXafb5gZhNDCLOaWqeTOmtX26e1u0Qb91yYXND61CzKrdCalVpet9QsCvWX8H9vF7I+NYtyK3fNStQtCpP0+qCQtz7uImleCOHN/N80uFNSk3+QFEgBahZZRN0ia6hZZA01i1QqZKLWV9LCRh8vymdfY2YjzGyqmU1dpRUF7A4oGDWLLGq2bqlZpAw1i6zh+gCpVMhELfa+StdCMoQwJoQwLIQwrL06FrA7oGDULLKo2bqlZpEy1CyyhusDpFIhE7VFkvo1+ngTSe8WNhygpKhZZBF1i6yhZpE11CxSqZCJ2guStjSzzcysg6QjJU0szrCAkqBmkUXULbKGmkXWULNIpVZ3fQwh1JnZ6ZL+rFwr05tCCK8WbWRAkVGzyCLqFllDzSJrqFmkVUF/Ry2E8LCkh4s0FqDkqFlkEXWLrKFmkTXULNKokLc+AgAAAABKgIkaAAAAAKQMEzUAAAAASBkmagAAAACQMkzUAAAAACBlmKgBAAAAQMowUQMAAACAlGGiBgAAAAApw0QNAAAAAFKGiRoAAAAApAwTNQAAAABIGSZqAAAAAJAy7So9gCyrHbyVy975Tk+XbXzl0+UYDgC0XTW1PurU0WULzxgSXb3fnz724bwFLgr19T5bsSLBAJFGH534LZdtcNdMlzUsW1aO4QBF9dFJvr4l
qddR/tz27Q1ed9lNs/z6vcf582rHP73QitEhCV5RAwAAAICUYaIGAAAAACnDRA0AAAAAUoaJGgAAAACkTEHNRMxsvqRlkuol1YUQhhVjUBVlFo2XHrWry7436nGXndJ9msu+E8512UZX0WCkUqqybhOq7dbVZSt22sJlb5/oGybEbNxjaTSfsv09Ltvmbyck2mafuzq4bJ0Jzydat1q15ZqNqdlha5e989/+3P3SzrdF1v57fKNnJtv39Z/1d9kN1x3kso2ueS6+gYZkP1tZl8aajTUOufT8MS677Ef7umxlXY/oNleO2dhl693TxPe+yGLH02XBKpe1f3RqOYaTeWms2Zao3XaQy26/8HfRZQe17+yy+tDgsnP2mOOypbt95bI9rznHZX1Hc51bDMXo+vjtEMKHRdgOUE7ULbKGmkXWULPIGmoWqcJbHwEAAAAgZQqdqAVJj5rZNDMbEVvAzEaY2VQzm7pK/K0ZpMJa65aaRQpRs8gaahZZwzUtUqfQtz7uHkJ418w2lPSYmb0WQnii8QIhhDGSxkhSF+sRCtwfUAxrrVtqFilEzSJrqFlkDde0SJ2CJmohhHfz/35gZvdL2kXSE2tfKz1qOnVy2RsX7xhd9pWjr0q4Vd8IYWWXlowKpZb1ul3Tm7/xN5RL0uHffcplvTvMd9mIbn9xWU3kxfYG+RuNY8s1tezsfxmbaLn/2maoy16e1i+6n7pF70TzalNtNdsS7529m8vOO+Uulx253hKXLQ8rXTZvVbxh1LiPd3fZkM4LXHZK17d9Nupql+03Pd48p+apGT6swgYjla7Zlfv6PhD3/eK3Lutdu47L9hr0oMt2fO6Y6H76TXrFZf6sVhpLB/rs9v93pctO+/HpLqt58qVSDCnTKl2zhaqf9brL/veDfaLLDu7sf3eOv/RAl33ez58vnzrZNyh5+Qx/DtzrzVOj+17v7mejOeJa/dZHM+tsZuuv/r+kf5M0s1gDA0qBukXWULPIGmoWWUPNIq0KeUVtI0n3W66dfTtJt4cQHinKqIDSoW6RNdQssoaaRdZQs0ilVk/UQghvSvpGEccClBx1i6yhZpE11CyyhppFWtGeHwAAAABSphh/8DqzYo1DkjcNSW6nA2a57KnNfcMEhfhN7olZpAFRC7bZb4Kft68z8YVkKweaHxXb+2f4JgrXn/17l32z0/To+quCb1DQ3mpd9vtPtnDZYx9u47I3H47cud6E/rf5hgsyX4s/mezv0/7lhv549t3SNwaQpNo20kykraj5hq+7kSff7bJY45BXV/nGIT8Y/58u63/R003svc4lc3b4V5f9+ue+YdSsPca57JE7ffMcSdrxct/YofdlTY0JrRXa+fNNrHFIUitXxC+XGr78stXbLFTPmf737oFP+wYOvxp7n8uuGXmEyzo9+HxxBobUeGb8TtH846PWdVmvh+e5rPsSf6794QM/cdnx9/3JZUPPmxbd99yJvpFfw1dfRZcFr6gBAAAAQOowUQMAAACAlGGiBgAAAAApw0QNAAAAAFKmzTQTqR28lctOOfDPZdn32P6Puaym/2SXNaihoP3URObdLdrmvpFtXuu3OfTyM1zGzfDFt6Knz3bs6L+fW0w5MfE2B9zob7DvMM3fQFy/9D2X9ZXPmuLbMkjaZXsX7dzpXZc1qPU3/CMb2m3SN5ofcZc/V/5o/Q9c9uCXXVw25sD9XNZ/TmHnpYYZr7ls4CVbu2zbnx/nsliDEUlq2P0zl9mV/ldxqIv+FCGhTiP9uSWpKct9s4M+f/RNZCqty+3PRjK/3OZv+YYQR106yWVX/Hif6H62GOVrtu6tSMMopM6GV8fPgZ9c3fptNsz058Wf3/dDl7129DXR9QdfcJrL+v/imdYPqMrxihoAAAAApAwTNQAAAABIGSZqAAAAAJAyTNQAAAAAIGWYqAEAAABAyrSZro9vHNnDZfd3n1OBkWTfhDN/47KD2p/nsr6j6QRZiE0v9l+/gy/e2WWb
66WC9lNf0NrJ1a3nu6b1rvUdHl9a6TtbdnhvWXSb5Ro7iit0XS+axzo8xpwz8ccu23yO74BXCtFOkP+zjcumPxDv2vjyrre67OANDnBZ3eLkXVbhPbK172q4KiRbd86KPi7rOOmFQodUMWf+7EyXPX6F78h3wh43R9f/nwd8x97n9untsvoPP2rF6FANBo7yXRtP/dfdo8vef8xlLjvnusNcxjkwh1fUAAAAACBlmKgBAAAAQMowUQMAAACAlGl2omZmN5nZB2Y2s1HWw8weM7O5+X+7l3aYQMtQt8gaahZZQ80ia6hZZE2SZiLjJF0t6ZZG2ShJk0MIo81sVP7jnxV/eK1Tu9GGLvvlEbdXYCTVqU+7ji6b8B++wchRi891Wffx/obTEhmnjNVttXvrUH+6aZBvHNIQ2uwL/eNEzSbSfaZVeghf0/DybJctrPMNrCRpSIelLltw9ECX9flNJm6kH6eU1mx98OeWpPp3WOK3t/c+0WVrH3+x1fspl67T/fH815IhLrtog+nR9S/o9YrLRj7S2WVzT9zWZQ3TZyUZYjmNU0prtto88fCO0fzak55y2etnb+aygedl4hxYcs1eEYUQnpD08RrxcEnj8/8fL+mQIo8LKAh1i6yhZpE11CyyhppF1rT2qeuNQgiLJSn/r38JC0gf6hZZQ80ia6hZZA01i9Qq+d9RM7MRkkZIUietW+rdAQWjZpE11CyyhppFFlG3KLfWvqL2vpn1lqT8v03+ldIQwpgQwrAQwrD28vc2AWWUqG6pWaQINYusoWaRNVzTIrVa+4raREnHShqd/3dC0UZUBPNHbOGy4Z0nVWAkbccmkQYjg072N9gvGe+ickp13Va7Plv6G9prIs8VTfzM34BcP3tuQfuu3XaQy1ZsvL7LOs1932V1CxcVtO8CVWfNrlgZjV9Zucpl23do77K6Qz7xK99c8KharV3fPi7rVjMjumyt+Zr/fFD865FRqajZLSef6LI5+9yQaN39113msnNOro8uu9njLRpWRdS//obLnj9tqMtGXeV/1iRp9MYvuOyy3s+6bLvvD3PZgHh/krRJRc1Wm4F/eDOaP/bjdVzWd8jiUg8ns5K0579D0jOStjKzRWZ2gnLF/F0zmyvpu/mPgdSgbpE11CyyhppF1lCzyJpmX1ELIRzVxKfivWqBFKBukTXULLKGmkXWULPImjb7B4sAAAAAIK2YqAEAAABAypS8PX8lBPNZrGlBoe7/oofLxv74YL/g868Ufd8fnfQtn+0WuSE98sW4dI97ots8tPOafwNSam+1LlsVEgxQ0s39J7vsQPkbmNE2TNne112DGlx2x9RdXdb5Z/Gb3GMGHuBvYL6g3x0u27Gj3/dDX/R02ejRP4rup8dNzyQeE76uft5b0fywh8902dxDrnPZXUPGuuwHZ57rsj53+CY09Ut8U5tCLfr+AJft2akuumx95Py52d1FHhC06V3+dxdvbvsne8p3+Zh1wrbRZSfd87rLvrfuZ0UfE6pL3eL3ovnpDxzvsvO/94DL7u22lcvqP217dccragAAAACQMkzUAAAAACBlmKgBAAAAQMowUQMAAACAlKnKZiIWuVk71rSgUGMX7unDEjQOiel5g29k0POGZOuOm7J7NB8+aILLYo1DSvG1RHV58ze+2U2NXows6Z8rmnfAH1zWoHgHmxr5ZjnXfrqZyx5aOsRlJ721vcv6/No3IOjxPE1DymXr82e77Pid9nbZzZs+7rJpP7vaZd/+3r/7nVzn60OS1vv7Gy6r//Ajl7Xr389lu/zw5eg2Y74z61CXdfyb/72RsGcTmtBx0gsu23n0GS67ZeTlLhvcvoPL+m/om21JUrsBm7qsbv6CJENMnYbps6L5R3XruaxGS1026/hrXPaNz093Wd/RT7didKgWNSv97+3jurzrsquOO8xlG/9v26sdXlEDAAAAgJRhogYAAAAAKcNEDQAAAABShokaAAAAAKRMVTYTGXHkw5Ue
QmrUduvqsl6dPq/ASFCNFj+wTTSfvNNvXdagdSKZb0wzYuG/uuzJN7aI7mfAjf6m5A7T5rmsfqm/8b2P4jfOo3Ji36f5o3dx2chffOmyyzZ+3mVTtrvX78T3O5Ak/erD7Vy2YHkPl+3UZbrLTun6dnyjERcOfMhlJ155gsv6/NU/j9rldf/1aaoBBLyNfu8bEZx18BEue3Tb+1z28NYPRLc59DDfoKT35dlsJtKUVcE3WWqqwdOaXjrj9y47cPTQgseEDPO/tqOWb0hLJYlX1AAAAAAgdZioAQAAAEDKMFEDAAAAgJRhogYAAAAAKdPsRM3MbjKzD8xsZqPsYjN7x8ym5x8HlHaYQHLULLKIukXWULPIGmoWWZOk6+M4SVdLumWN/IoQwu+KPqIiOK3bGy7zveWqz/JDfHe0mlM/cNkNm95TjuFU0jhlrGYr6aMTvuWy40f6znQjus53WY1ejG4z1uHx/frlLjv04nNd1uOmZ1y2uV6K7iemPvGSqTNO1K2zzgTfzfH1v/lutrt+/zSX7XbKVJdd0fu56H4u7DUzmhfb3p1WuWzeQdf7BQ/y0eTlHV122RaDizGs1hqnNl6zhx77N5fdt+c3XNbn0Ox257zpNwe77IRLrq7ASIpinNp4zVZcwmaOtV8lbA9Z5Zp9RS2E8ISkj8swFqAoqFlkEXWLrKFmkTXULLKmkHvUTjezGfmXkbs3tZCZjTCzqWY2dZVWFLA7oGDULLKo2bqlZpEy1CyyhusDpFJrJ2rXSdpc0hBJiyVd1tSCIYQxIYRhIYRh7eXftgGUCTWLLEpUt9QsUoSaRdZwfYDUatVELYTwfgihPoTQIOkGSf7mKCBFqFlkEXWLrKFmkTXULNIsSTMRx8x6hxAW5z88VFJ57sJOaMTCvV02pt/jRd/Pvhv5m4PHn7Wfyza+8umC9lO70YYu6/XAVy67pf8Yl60KsdYKyefn7a02ss1k6+758hEu66p5ifddTGmv2VKINQlZedCnLntop9+6rHetbway14wfuOyJHe6O7rsh0r5n/2kjXNYn0jgE/9QW6zaJ+k8/c1nPG30tzbvXv4Nph1NOj25zxunFbY5w6Lx447iPl6/b6m0e0W9aq9ctl6zV7DpHLnPZ/tuf5LLzx46Prn9hrxkuO7SLb7I0aucTXVa7aInL6ha/F91PJW3wlwU+vKT84yiVrNVs1m23R7LrwIFj57usrshjyYJmJ2pmdoekvSX1MrNFki6StLeZDVGud8t8SSeXcIxAi1CzyCLqFllDzSJrqFlkTbMTtRDCUZF4bAnGAhQFNYssom6RNdQssoaaRdYU0vURAAAAAFACTNQAAAAAIGVa1Uwk7d4/qJPL7n+qh8uGd/6woP2c1n2Oy4455xWXTThlc5fVmm+2UB/i8+b1a99yWWzsqyLrx5o6tESscUhsm4vq/N8T6drRNzxBYdr128Rlcy7tFV32j9/0zRE2abfcZdEmH7/2TWS6PO9ru+Ydi+479hxQeLZbE8sCpWFd1nPZN4f75g8t8eCXXVx2xX/+0GXrPPpydP3OK1rfLOKRfttH0kWt3h6k+o/83z6ufdxnP70uftvSi2f/3mWDO/hLqwcfGOey3V7y78Lb8PQOLqubH2nmUUaLDu9f0f0j/WJN7xYes0V02Vv7XeGyB77YwGXhS3+90hbxihoAAAAApAwTNQAAAABIGSZqAAAAAJAyTNQAAAAAIGWqsplI/ZIlLnv8s21cNrzzk0Xf9/o1/kbgH3dZ6LKayBy50MYf5RJrHDL82vNc1vfSp8sxnOq1i28c8JM/TnDZwZ0/ia5+zae+ic1tl+3vsj43PdPq8TRoWnTRWC33v+1tl9Ul2zPQrC++v6vLtjl3psuu3yT5eX/Sl74ZybXHfd9lnZ563mWRPkwFq1tI45BK6Tvls2g+tOEMl11x2h9ctlenlS57esc7XHbA2ENctujjWBMZqfOjvj573pDwfB6xeORu0fyOMy6LpP5a
B21DrHHIgIeWuuzBPr6hWU57l2zf4VOX3TbBT1HmTvQ1uu778bNtt1tb/7OQJryiBgAAAAApw0QNAAAAAFKGiRoAAAAApAwTNQAAAABImapsJhLz0hVDfPjb4jcTqTYnLPi2y+ZcN9hlfW+hcUixHTzucZ9FGofseJW/mV2KN+/osaiAG80v8K0/amTRZbe693SXbbnouVbvG21Xu94bu2zOyAEue/oI3/CgZ806ifcTaxxy9XE/cJk9NT3xNlE9wrRXo3nvSD+l0S8c47LO429MtJ+7B93lsi41naLLTh/mz8kvnL1Zov3E7LlurGmINKh9ssYhi+qWu+zw/z7XZT1VHU0e2qq3r93AZQ/2ecRlL6yIN/nYuaO/bnhl5SqXbbzOMp8d8WKSIUqS5k/d0mX1s+cmXj8teEUNAAAAAFKGiRoAAAAApAwTNQAAAABIGSZqAAAAAJAyzU7UzKyfmU0xs9lm9qqZnZXPe5jZY2Y2N/9v99IPF2geNYusoWaRRdQtsoaaRdYk6fpYJ2lkCOFFM1tf0jQze0zScZImhxBGm9koSaMk/ax0Qy1Ml9ufddn2O5zpslePvrocw1F7q3XZqniDnKJvc/Yq311Hkob/1XfqG/STqS7rlv6OTVVRsyO6zndZgxpc1unDeOHULXqn1fv+6IRvueyhnX4bGU+8q97A++M1hiZVRc3G1Hbr6rL5p/nOsU351kEzXDZp0+tdVh+SdXjcYtLJ0XzbXyxwmb1Hh8dmVG3dFqLmyZdcdtHAoYnWff3aXVz2naHxjpPXbvKEy3bo4Os4uWTdHSVpynLfifLnl/hriJ43p+56gZptgbDbN1z2552vcdlWd57jstpNvoxuc9Ye41x23BVnu2yjqwrtJp69Do8xzb6iFkJYHEJ4Mf//ZZJmS+orabik8fnFxks6pFSDBFqCmkXWULPIIuoWWUPNImtadI+amQ2QtKOk5yRtFEJYLOUKX9KGTawzwsymmtnUVVpR2GiBFqJmkTXULLKopXVLzaLSONciCxJP1MxsPUn3SvppCGFp0vVCCGNCCMNCCMPaq2Nrxgi0CjWLrKFmkUWtqVtqFpXEuRZZkWiiZmbtlSvo20II9+Xj982sd/7zvSV9UJohAi1HzSJ46/iXAAAJHUlEQVRrqFlkEXWLrKFmkSXNNhMxM5M0VtLsEMLljT41UdKxkkbn/51QkhGW0MCLXnTZ4OBvhpWkkw561GVndn+t1fuONfmINYqQpEV1/uX1U+cd6TIzv9E5r/d12TZXfhLdz6DZvnFIFlVLze5z6n+4bOFBvkbOGflQdP2G//TPw1zx1/1ctvV1vh46Hf6+y3rX+mYNg/4Ub8wwaEp11FK5pL1mVxyws8s+2LG9y7q+5evzpv+53GWD2k8paDz1kfPnk1/5X2e/POl4l231lG9OIkl1K3gbU0ulvW6zaNCpz7vsnS0HRpfd7tj49Uo5dHnDZz3S1zjEoWabVturp8vOu/VWl728spfLtvrdfJfN+9/ou0f1Rt1yl/V90Dc/q4uu3fYk6fq4u6SjJb1iZqtbYF2gXDHfbWYnSFog6fDSDBFoMWoWWUPNIouoW2QNNYtMaXaiFkL4uyRr4tP7FHc4QOGoWWQNNYssom6RNdQssqZFXR8BAAAAAKXHRA0AAAAAUibJPWpVK0RuHt/s/PjNsFNuHeqysUfs67JIPw91f83fYB8iL7zH1pWk9l/49Ts96G84jq0+SP4Gzfr4bpAy60zw3+NBkdubJ3XZLLr+yqFbuOzsqx9x2U7fe8tlu3T01RRrdrPtxYuj++Ym4OrS5+fzXDZ5wGSXLQ8rXbaOdSpo39NW+jPWsePPctmAy15xWbtl01zWxGkWSK36uW9G8wEXxnOgNUJf3/xjz07+t/n215/isn6Ln3ZZr67rRvfz7PL+Lqt76+0kQ2yTeEUNAAAAAFKGiRoAAAAApAwTNQAAAABIGSZqAAAAAJAybbqZSEvUz3rdZf0v8hlQ
bvVLl0bz2ikvuuyhwd1ddvmN33XZWwfc6LLjF+ztsrpFvlkNqs/8q7Zy2Za7+SypAZPi7WY6PjHTZSH49h+brvA3rvtWNwCApN44olui5Ta7ab7L6ocOdtkT298aXX+Lib4ZySD55mnI4RU1AAAAAEgZJmoAAAAAkDJM1AAAAAAgZZioAQAAAEDK0EwEaOM6z+3gslWh3mXP/ml7l20q39QB1Wf9u56NZMXfDw1BAKAyNrvgGZcdcMFOkSXf9dE7Pjugb2xdGoe0FK+oAQAAAEDKMFEDAAAAgJRhogYAAAAAKcNEDQAAAABSptlmImbWT9ItkjZW7l7vMSGEK83sYkknSVqSX/SCEMLDpRookBQ12zJ9L/UNQQ68dKjLaBxSOtQssoaaRRZRt8iaJF0f6ySNDCG8aGbrS5pmZo/lP3dFCOF3pRse0CrULLKGmkXWULPIIuoWmdLsRC2EsFjS4vz/l5nZbEl9Sz0woLWoWWQNNYusoWaRRdQtsqZF96iZ2QBJO0p6Lh+dbmYzzOwmM+vexDojzGyqmU1dpRUFDRZoKWoWWUPNImuoWWQRdYssSDxRM7P1JN0r6achhKWSrpO0uaQhyj07cVlsvRDCmBDCsBDCsPbqWIQhA8lQs8gaahZZQ80ii6hbZEWiiZqZtVeuoG8LIdwnSSGE90MI9SGEBkk3SNqldMMEWoaaRdZQs8gaahZZRN0iS5qdqJmZSRoraXYI4fJGee9Gix0qaWbxhwe0HDWLrKFmkTXULLKIukXWJOn6uLukoyW9YmbT89kFko4ysyGSgqT5kk4uyQiBlqNmkTXULLKGmkUWUbfIlCRdH/8uySKf4u9LIJWoWWQNNYusoWaRRdQtsqZFXR8BAAAAAKXHRA0AAAAAUoaJGgAAAACkDBM1AAAAAEgZJmoAAAAAkDJM1AAAAAAgZZioAQAAAEDKWAihfDszWyLp7fyHvSR9WLadl1Y1HYuU3uPpH0LYoJw7pGYzI63HQ80WTzUdi5Tu4ylr3VZxzUrVdTxpPpZKnmvT/HVpjWo6njQfS6KaLetE7Ws7NpsaQhhWkZ0XWTUdi1R9x1Ms1fR1qaZjkarveIqlmr4u1XQsUvUdT7FU29elmo6nmo6lmKrt61JNx1MNx8JbHwEAAAAgZZioAQAAAEDKVHKiNqaC+y62ajoWqfqOp1iq6etSTcciVd/xFEs1fV2q6Vik6jueYqm2r0s1HU81HUsxVdvXpZqOJ/PHUrF71AAAAAAAcbz1EQAAAABShokaAAAAAKRM2SdqZrafmc0xs3lmNqrc+y+Umd1kZh+Y2cxGWQ8ze8zM5ub/7V7JMSZlZv3MbIqZzTazV83srHyeyeMpFWo2PajZZKjZ9KBmk8ty3VZTzUrUbVJZrlmpuuq2Wmu2rBM1M6uVdI2k/SVtK+koM9u2nGMognGS9lsjGyVpcghhS0mT8x9nQZ2kkSGEbSR9U9Jp+e9HVo+n6KjZ1KFmm0HNpg41m0AV1O04VU/NStRts6qgZqXqqtuqrNlyv6K2i6R5IYQ3QwgrJd0paXiZx1CQEMITkj5eIx4uaXz+/+MlHVLWQbVSCGFxCOHF/P+XSZotqa8yejwlQs2mCDWbCDWbItRsYpmu22qqWYm6TSjTNStVV91Wa82We6LWV9LCRh8vymdZt1EIYbGUKxRJG1Z4PC1mZgMk7SjpOVXB8RQRNZtS1GyTqNmUombXqhrrtiq+x9Rtk6qxZqUq+B5XU82We6JmkYy/D1BhZraepHsl/TSEsLTS40kZajaFqNm1omZTiJptFnWbQtTtWlGzKVRtNVvuidoiSf0afbyJpHfLPIZSeN/MektS/t8PKjyexMysvXIFfVsI4b58nNnjKQFqNmWo2WZRsylDzSZSjXWb6e8xddusaqxZKcPf42qs2XJP1F6QtKWZbWZmHSQdKWlimcdQChMlHZv//7GSJlRwLImZmUkaK2l2COHyRp/K5PGUCDWbItRsItRsilCz
iVVj3Wb2e0zdJlKNNStl9HtctTUbQijrQ9IBkl6X9Iakn5d7/0UY/x2SFktapdyzKSdI6qlcJ5m5+X97VHqcCY9lD+Vepp8haXr+cUBWj6eEXydqNiUPajbx14maTcmDmm3R1yqzdVtNNZs/Huo22dcpszWbH3/V1G211qzlDw4AAAAAkBJl/4PXAAAAAIC1Y6IGAAAAACnDRA0AAAAAUoaJGgAAAACkDBM1AAAAAEgZJmoAAAAAkDJM1AAAAAAgZf4/3ubVFkcaL6AAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 1080x5400 with 5 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(15, wrong_count*15))\n",
    "for i, (img, preds, truth) in enumerate(wrong_classify):\n",
    "    img = img.reshape(28, 28).cpu().numpy()\n",
    "    plt.subplot(1, wrong_count, i+1)\n",
    "    plt.imshow(img)\n",
    "    plt.title('true:%i, pred:%i' % (truth, preds))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-03T07:12:44.915877Z",
     "start_time": "2019-11-03T07:12:44.913744Z"
    }
   },
   "source": [
    "### Save Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-05T03:36:41.076182Z",
     "start_time": "2019-11-05T03:36:41.019197Z"
    }
   },
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), 'cnn_cuda.pkl')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "oldHeight": 511.99678000000006,
   "position": {
    "height": "533.984px",
    "left": "411.875px",
    "right": "20px",
    "top": "161.984px",
    "width": "760.516px"
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "varInspector_section_display": "block",
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
